from gd import *
from .helpers import np_expit, softplus, softplus_thresholded, multiple_logpdfs_one_sigma, mvn_logpdf, hardplus
import autograd.numpy as np
import autograd.scipy.stats.beta as beta
import autograd.scipy.stats.norm as norm
from .ContinuousFactor import ContinuousFactor
import random


class ContinuousModel:

    def __init__(self, data, num_subtrees=0, subtree_loadings=None, max_depth=3, seed=None,
                 inference='full'):
        """Store the data and hyper-parameters; ``configure`` builds the factors.

        Args:
            data: 2-D array exposing ``shape``; ``shape[0]`` is stored as ``N``
                and ``shape[1]`` as ``D``.
            num_subtrees: number of subtrees ``configure`` requests from the
                root (0 disables subtrees).
            subtree_loadings: forwarded to ``root.create_subtrees``.
            max_depth: bound used by the first ``configure`` call when
                repeatedly splitting leaves.
            seed: seed applied to both NumPy's and the stdlib's global RNG.
            inference: inference mode. ``set_latents`` only updates factor
                means when this equals ``'full'``.
        """
        self.data = data
        self.num_subtrees = num_subtrees
        self.subtree_loadings = subtree_loadings
        self.max_depth = max_depth
        self.N = data.shape[0]
        self.D = data.shape[1]
        self.seed = seed
        # Previously read from an undefined name `inference`; it is now an
        # explicit keyword argument with a default.
        self.inference = inference

        np.random.seed(self.seed)
        random.seed(self.seed)

        self.factors = []
        self.iter_counter = 0
        self.prev_params = None
        self.prev_objective = None
        # Large sentinel so the first observed difference always replaces it.
        self.min_objective_diff = 9999999999999

        self.nn = True
        self.sigma_scale = 0.1
        self.beta_hyper = (10, 10)

    def configure(self):
        """Grow the factor tree and rebuild the flat parameter vector.

        On the first call (no factors yet) a root is created, optional
        subtrees are attached, and every leaf is split once per level while
        the level counter is at most ``max_depth - 1``. On later calls all
        existing factors are marked fixed and the current leaves are split
        one more time. In both cases ``latent_z``, ``latent_mu``,
        ``latent_sigma`` and ``params`` are re-created and pushed into the
        latent (non-fixed) factors.
        """
        def grow_one_level():
            # Collect all children first, then append, so the new factors
            # are added to self.factors only after the scan has finished.
            children = [
                child
                for node in self.factors
                if len(node.children) == 0
                for child in node.split()
            ]
            self.factors.extend(children)

        if self.factors != []:
            # Re-configuration: freeze everything learnt so far, add a level.
            self.set_all_fixed()
            self.depth += 1
            grow_one_level()
        else:
            root = self.new_root()
            if self.num_subtrees > 0:
                self.factors.extend(
                    root.create_subtrees(self.num_subtrees, self.subtree_loadings))
            # NOTE(review): self.depth stays at 1 even when the loop below
            # adds several levels - confirm that this is intended.
            self.depth = 1
            level = self.depth
            while level <= self.max_depth - 1:
                grow_one_level()
                level += 1

        n_latent = self.get_n_latent_factors()
        self.n_latent_factors = n_latent

        # could also make this random
        per_factor_z = np.array(0.5 * np.ones(self.N))
        self.latent_z = np.tile(per_factor_z, n_latent)

        # could also make this random
        seed_rows = random.sample(range(self.N), n_latent)
        self.latent_mu = np.array(self.data[seed_rows, :]).flatten()

        # could also make this random
        self.latent_sigma = np.array([0.1])

        # Flat layout: all z values, then all mu values, then sigma.
        self.params = np.hstack([self.latent_z, self.latent_mu, self.latent_sigma])

        self.initialize_zs()
        self.initialize_mus()

    def initialize_zs(self):
        """Give every latent factor its length-``N`` slice of ``latent_z``.

        ``latent_z`` is stored factor-major: ``configure`` tiles one
        length-``N`` vector per latent factor, and ``unpack_params`` reads
        factor ``i`` from positions ``[i * N, (i + 1) * N)``. The previous
        ``reshape((N, K))[:, i]`` picked every K-th element instead, which
        only matched while all initial values were identical.
        """
        latent = self.latent_factors()
        if not latent:
            return
        # One row per latent factor, matching the layout in unpack_params.
        z_rows = self.latent_z.reshape((len(latent), self.N))
        for factor, z_row in zip(latent, z_rows):
            factor.set_z(z_row)

    def initialize_mus(self):
        """Give every latent factor its length-``D`` slice of ``latent_mu``.

        ``latent_mu`` is viewed as one row of ``D`` values per latent factor;
        row ``i`` goes to the ``i``-th latent factor.
        """
        for row, factor in enumerate(self.latent_factors()):
            mu_table = self.latent_mu.reshape((self.get_n_latent_factors(), self.D))
            # .T is a no-op on the 1-D row; kept to mirror the stored form.
            factor.set_mu(mu_table[row, :].T)

    def set_latents(self, mus, zs, sigma):
        """Push unpacked parameter values into the latent factors.

        Args:
            mus: indexable; entry ``i`` is passed to the ``i``-th latent
                factor's ``set_mu`` (only when ``self.inference == 'full'``).
            zs: indexable; entry ``i`` is passed to the ``i``-th latent
                factor's ``set_z``.
            sigma: stored as ``self.latent_sigma`` after the factors are set.
        """
        for slot, latent_factor in enumerate(self.latent_factors()):
            # Means are only refreshed in 'full' inference mode.
            if self.inference == 'full':
                latent_factor.set_mu(mus[slot])
            latent_factor.set_z(zs[slot])
        self.latent_sigma = sigma

    def set_all_fixed(self):
        """Mark every factor currently in the model as fixed."""
        for node in self.factors:
            node.fixed = True

    def new_root(self):
        """Create a parentless depth-0 factor, register it and return it."""
        root = self.blank_factor(None, True, 0)
        self.factors.append(root)
        return root

    def blank_factor(self, parent, orientation, depth, subtree=False, subtree_loading=None):
        """Build a ``ContinuousFactor`` owned by this model.

        The positional arguments are forwarded as ``(parent, self,
        orientation, depth)``; ``subtree`` and ``subtree_loading`` are
        forwarded as keywords. The factor is not added to ``self.factors``.
        """
        factor = ContinuousFactor(
            parent,
            self,
            orientation,
            depth,
            subtree=subtree,
            subtree_loading=subtree_loading,
        )
        return factor

    def unpack_params(self, params):
        """Split the flat parameter vector into means, memberships and sigma.

        Layout of ``params``: ``n_latent_factors`` blocks of ``N`` z values,
        then ``n_latent_factors`` blocks of ``D`` mu values, with sigma as
        the final element.

        Returns:
            ``(mus, zs, sigma)``. Each z block is passed through ``np_expit``;
            the stacked mus go through ``softplus_thresholded`` when
            ``self.nn`` is truthy and are returned raw otherwise.
        """
        n_obs = self.N
        n_latent = self.n_latent_factors
        mu_start = n_latent * n_obs

        z_blocks = [
            np_expit(params[k * n_obs:(k + 1) * n_obs])
            for k in range(n_latent)
        ]
        mu_blocks = [
            params[mu_start + k * self.D:mu_start + (k + 1) * self.D]
            for k in range(n_latent)
        ]
        sigma = params[-1]

        mus = np.array(mu_blocks)
        if self.nn:
            mus = softplus_thresholded(mus)
        return mus, np.array(z_blocks), sigma

    def objective(self, params, t):
        """Return the negated model objective for ``params``.

        The objective is the sum of the Gaussian likelihood terms, the z
        likelihood, the mu prior and the sigma prior. As a side effect the
        latent factors are updated from ``params``, diagnostics are printed,
        and ``prev_objective`` / ``min_objective_diff`` are refreshed.

        Args:
            params: flat parameter vector in the layout ``unpack_params``
                expects.
            t: iteration index supplied by the caller; not used here.

        Returns:
            ``-1 *`` the objective, so that a minimiser maximises it.
        """
        latent_mus, latent_zs, latent_sigma = self.unpack_params(params)
        self.set_latents(latent_mus, latent_zs, latent_sigma)

        view1_gaussian = np.sum(self.gaussian_likelihood())
        z_likelihood = self.z_likelihood()

        model1_objective = view1_gaussian + z_likelihood + self.mu_prior() + self.sigma_prior()

        print('VIEW 1 Gaussian:', view1_gaussian)
        print('SIGMA1', self.latent_sigma)
        print('VIEW 1 OBJECTIVE:', model1_objective)
        print('Z likelihood', z_likelihood)

        if self.prev_objective is not None:
            # Track the smallest change seen between consecutive evaluations.
            objective_diff = abs(model1_objective - self.prev_objective)
            self.min_objective_diff = min(self.min_objective_diff, objective_diff)
        # The first-call branch used `==` (a no-op comparison), so the
        # previous objective was never recorded; it is now always stored.
        self.prev_objective = model1_objective

        return -1 * (model1_objective)

    def project_data(self):
        """Sum the ``partial()`` contribution of every factor.

        Returns:
            The element-wise sum, over all factors (fixed or not), of the
            arrays returned by ``factor.partial()``.
        """
        contributions = []
        for factor in self.factors:
            contributions.append(factor.partial())
        # Stack along a new leading axis and collapse it again.
        return np.sum(np.array(contributions), axis=0)

    def gaussian_likelihood(self):
        """Evaluate the data under the current projection.

        Returns:
            Whatever ``multiple_logpdfs_one_sigma`` yields for the data, the
            projected means from ``project_data`` and ``latent_sigma``
            (presumably one log-density per data row - confirm in helpers).
        """
        predicted = self.project_data()
        return multiple_logpdfs_one_sigma(self.data, predicted, self.latent_sigma)

    def callback(self, params, t, g):
        """Per-iteration hook: print progress and sync the latent factors.

        Args:
            params: current flat parameter vector.
            t: iteration index; not used here.
            g: gradient at ``params``; not used here.
        """
        previous = self.prev_params
        if previous is not None:
            print('param diff:', '\n', np.mean(abs(params - previous)))

        unpacked = self.unpack_params(params)
        mus, zs, sigma = unpacked
        print(zs[0], sigma)
        self.set_latents(mus, zs, sigma)

        # Remember this iterate so the next call can report the change.
        self.prev_params = params
        self.iter_counter += 1

    def z_likelihood(self):
        """Sum the Beta log-density of the z values of latent factors.

        Only latent (non-fixed) factors with ``depth > 0`` contribute; the
        Beta shape parameters come from ``self.beta_hyper``.
        """
        selected = []
        for factor in self.latent_factors():
            if factor.depth > 0:
                selected.append(factor.z)
        log_density = beta.logpdf(np.array(selected), *self.beta_hyper)
        return np.sum(log_density)

    def mu_prior(self):
        """Sum the prior log-density of the latent factor means.

        The means of all latent factors are scored by
        ``multiple_logpdfs_one_sigma`` against an all-zero mean of shape
        ``(n_latent_factors, D)`` with sigma 1, and the results are summed.
        """
        stacked_means = np.array([factor.mu for factor in self.latent_factors()])
        origin = np.zeros((self.n_latent_factors, self.D))
        return np.sum(multiple_logpdfs_one_sigma(stacked_means, origin, 1))

    def sigma_prior(self):
        """Return the log-density of ``latent_sigma`` under N(0, sigma_scale)."""
        sigma = self.latent_sigma
        return norm.logpdf(sigma, 0, self.sigma_scale)

    def latent_factors(self):
        """Return the factors that are not fixed, in model order."""
        free = []
        for factor in self.factors:
            if not factor.fixed:
                free.append(factor)
        return free

    def get_n_latent_factors(self):
        """Count the factors that are not fixed."""
        return sum(1 for factor in self.factors if not factor.fixed)

    def assemble_trees(self):
        """Return the factors of each tree in breadth-first order.

        With ``num_subtrees > 0`` one list is produced per subtree, rooted at
        ``self.factors[1]`` through ``self.factors[num_subtrees]``. Otherwise
        a single list rooted at ``self.factors[0]`` is produced.

        Returns:
            A list of lists of factors; each inner list starts with its root
            and is followed by descendants level by level.
        """
        def breadth_first(root):
            # The output list doubles as the queue: in a BFS the visit order
            # equals the enqueue order, so a cursor replaces list.pop(0)
            # (which is O(n) per pop).
            ordered = [root]
            cursor = 0
            while cursor < len(ordered):
                ordered.extend(ordered[cursor].children)
                cursor += 1
            return ordered

        if self.num_subtrees > 0:
            roots = [self.factors[i + 1] for i in range(self.num_subtrees)]
        else:
            roots = [self.factors[0]]
        return [breadth_first(root) for root in roots]

    def update(self, x):
        """Constrain the last entry of ``x`` in place and return ``x``.

        The last entry is replaced by ``hardplus`` of itself (presumably a
        clamp to non-negative values - confirm in helpers); if the stored
        result is exactly zero it is bumped to 0.001.
        """
        floor_value = 0.001
        x[-1] = hardplus(x[-1])
        # Compare the stored value, not the helper's raw return value.
        if x[-1] == 0:
            x[-1] = floor_value
        return x
